Skip to content

Conversation

@kausikmaiti
Copy link

@kausikmaiti kausikmaiti commented Dec 3, 2025

Added support for Column Major C [Bias] in GEMM.

Following changes have been made.

  1. XE_LOAD_2D or XE_LOAD_2D_TRANSPOSE used in TiledCopyC based on C [Bias] layout.
  2. Static assertion check on C [Bias] layout in xe_builder.inl has been removed.
  3. Reference output generation function GemmComplex has been fixed.
  4. 00_bmg_gemm.cpp example test has been updated accordingly.

To do:

  1. Python test to be added.

@kausikmaiti kausikmaiti marked this pull request as draft December 3, 2025 17:37
@kausikmaiti kausikmaiti requested a review from petercad December 3, 2025 17:38
@kausikmaiti kausikmaiti force-pushed the column_major branch 2 times, most recently from b0e4187 to 31cdc64 Compare December 9, 2025 14:28
@kausikmaiti kausikmaiti changed the title [WIP] Support Column Major Bias [C] Added support for Column Major C [Bias] in GEMM Dec 9, 2025
@kausikmaiti kausikmaiti marked this pull request as ready for review December 9, 2025 14:45
std::cout << "\n\nRunning BMG GEMM with bfloat16, RowMajor Bias and bfloat16, RowMajor Output" << std::endl << std::flush;
test_bmg_gemm<bfloat16_t, cutlass::layout::RowMajor, bfloat16_t>(options, hw_info);
std::cout << "\n\nRunning BMG GEMM with bfloat16, ColumnMajor Bias and bfloat16, RowMajor Output" << std::endl << std::flush;
test_bmg_gemm<bfloat16_t, cutlass::layout::ColumnMajor, bfloat16_t>(options, hw_info);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if the 00_bmg_gemm example is the right place to test all these options -- as the very first example, I think it should probably be the simplest one. These feels like something that belongs as a test, or maybe we could have a separate 00_bmg_gemm_bias or 00_bmg_gemm_with_beta executable?


using ActualGmemTiledCopyC = replace_void_t<CopyOpG2R, DefaultCopyOpG2R>;
constexpr bool IsColMajorC = cutlass::gemm::detail::is_major<0, StrideC>();
using ActualGmemTiledCopyC = replace_void_t<CopyOpG2R, std::conditional_t<IsColMajorC, CopyOpG2RTransposed, DefaultCopyOpG2R>>;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest refactoring this a bit to make the logic/naming clearer:

using DefaultCopyOpR2GNontranspose = XE_STORE_2D<CopyBitsD, cute::gcd(8, get<0>(EpilogueTile{})), cute::gcd(512 / CopyBitsD, get<1>(EpilogueTile{}))>;
using DefaultCopyOpR2GTranspose = XE_LOAD_2D_TRANSPOSE<CopyBitsCTranspose, cute::gcd(512 / CopyBitsC, get<1>(EpilogueTile{})), cute::gcd(8 / Sub32BitFactor, get<0>(EpilogueTile{}))>;

using DefaultCopyOpR2G = conditional_t<IsColMajorC, DefaultCopyOpR2GTranspose, DefaultCopyOpR2GNontranspose>;
using ActualGmemTiledCopyC = replace_void_t<CopyOpG2R, DefaultCopyOpG2R>;

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants